#!/usr/bin/env python

import json
import sys
import jinja2
import pprint
import napalm
import paramiko.ssh_exception
import jnpr.junos.exception
import time

class RouterList:
    """Ordered collection of router objects that can be looked up by name."""

    def __init__(self):
        # Routers are kept in insertion order.
        self.routers = []

    def append(self, router):
        """Add *router* to the end of the collection."""
        self.routers.append(router)

    def get(self, name):
        """Return the first router whose ``name`` equals *name*, else None."""
        matches = (candidate for candidate in self.routers
                   if candidate.name == name)
        return next(matches, None)

    def list(self):
        """Return the underlying list of routers (the same object, not a copy)."""
        return self.routers

class Router:
    """One router plus the NAPALM device handle used to configure it.

    Instances start out empty; the caller fills in ``id``, ``name``, ``type``,
    ``template``, ``links`` and extra attributes after construction.
    """

    # Credentials handed to every NAPALM driver instance.
    USERNAME = "vrnetlab"
    PASSWORD = "VR-netlab9"

    # Router type (``self.type``) -> NAPALM network driver name.
    DRIVERS = {"xrv": "iosxr", "vmx": "junos", "csr": "ios"}

    def __init__(self):
        self.id = 0           # numeric identifier assigned by the caller
        self.name = ""        # host / container name, resolved by get_address()
        self.type = ""        # one of the keys of DRIVERS
        self.template = None  # template reference assigned by the caller
        self.config = None    # configuration text passed to load_merge()
        self.links = []       # link descriptions appended by the caller
        self.attrs = {}       # extra key/value attributes (see add_attr)
        self.device = None    # NAPALM device, created by connect()
        self.loaded = False   # True once a candidate config has been loaded

    def connect(self, wait_for_boot=False, timeout=60 * 15, retry_interval=2):
        """Create the NAPALM device for this router and open the session.

        Args:
            wait_for_boot: when true, keep retrying failed connection attempts
                until one succeeds or *timeout* expires.
            timeout: maximum number of seconds to keep retrying.
            retry_interval: seconds to sleep between attempts.

        Returns:
            The opened NAPALM device (also stored in ``self.device``).

        Raises:
            ValueError: if ``self.type`` is not a known router type.
            TimeoutError: if *wait_for_boot* is set and the router could not
                be reached before *timeout* elapsed.
        """
        driver_name = self.DRIVERS.get(self.type)
        if driver_name is None:
            raise ValueError("Unknown device type: {}".format(self.type))

        driver = napalm.get_network_driver(driver_name)
        self.device = driver(self.get_address(), self.USERNAME, self.PASSWORD)

        if not wait_for_boot:
            self.device.open()
            return self.device

        # xrv and csr raise SSHException when the connection fails, vmx
        # raises ConnectError; both just mean "not up yet, try again".
        retryable = (
            paramiko.ssh_exception.SSHException,
            jnpr.junos.exception.ConnectError,
        )

        # Use a wall-clock deadline so that time spent blocked inside open()
        # counts towards the timeout, not only the time spent sleeping.
        deadline = time.monotonic() + timeout
        while True:
            try:
                self.device.open()
                return self.device
            except retryable:
                if time.monotonic() >= deadline:
                    raise TimeoutError(
                        "Timed out wating for router {} to start".format(self.name))
                time.sleep(retry_interval)

    def load_merge(self):
        """Load ``self.config`` as a merge candidate, connecting first if needed.

        ``self.loaded`` is only set once the device accepted the candidate, so
        a failed load does not unlock commit/compare/discard.

        Returns:
            Whatever the device's ``load_merge_candidate`` returns.
        """
        if not self.device:
            self.connect()
        result = self.device.load_merge_candidate(config=self.config)
        self.loaded = True
        return result

    def _require_loaded(self):
        """Raise if no candidate configuration has been loaded yet."""
        if not self.loaded:
            raise RuntimeError("Please run load_merge first")

    def commit_config(self):
        """Commit the loaded candidate configuration on the device."""
        self._require_loaded()
        return self.device.commit_config()

    def discard_config(self):
        """Discard the loaded candidate configuration on the device."""
        self._require_loaded()
        return self.device.discard_config()

    def compare_config(self):
        """Return the device's comparison for the loaded candidate."""
        self._require_loaded()
        return self.device.compare_config()

    def get_address(self):
        """Return the address used to reach this router.

        The router name is resolved with ``socket.gethostbyname``; if that
        fails, the IP address reported by ``docker inspect`` for a container
        of the same name is returned instead.
        """
        import subprocess
        import socket

        try:
            return socket.gethostbyname(self.name)
        except socket.gaierror:
            # Try docker inspect if gethostbyname fails
            cmd = ["docker", "inspect", "--format",
                   "'{{.NetworkSettings.IPAddress}}'", self.name]
            p = subprocess.Popen(cmd, stdout=subprocess.PIPE, cwd=".")
            output = p.communicate()[0]
            # communicate() returns bytes on Python 3; decode before doing
            # any str operations on it.
            if isinstance(output, bytes):
                output = output.decode("utf-8")
            return str(output.strip().replace("'", ""))

    def add_attr(self, key, value):
        """Store an extra attribute under *key*, replacing any previous value."""
        self.attrs[key] = value

class ConfigBootstrap:
    """Build routers from CLI input or a topology and push rendered configs.

    Typical order of calls: ``load_router``/``load_topology``, ``connect``,
    ``render_config``, then ``apply_config`` and/or ``diff_config``.
    """

    def __init__(self, wait_for_boot):
        """Start with no routers; *wait_for_boot* is forwarded to each connect."""
        self.routers = RouterList()
        self.wait_for_boot = wait_for_boot

    def load_router(self, name, model, template, attrs):
        """Register a single router.

        Args:
            name: router name.
            model: router type, stored as ``router.type``.
            template: template reference used later by ``render_config``.
            attrs: mapping of extra attributes exposed to the template.
        """
        router = Router()
        router.id = 1
        router.name = name
        router.type = model
        router.template = template

        for key in attrs:
            router.add_attr(key, attrs[key])

        self.routers.append(router)

    def load_topology(self, topology, xr, junos, ios):
        """Replace the router list with the routers and links of *topology*.

        Args:
            topology: dict with a ``"routers"`` mapping (name -> settings,
                each with a ``"type"`` and optionally an ``"id"``) and an
                optional ``"links"`` list whose items have ``"left"`` and
                ``"right"`` ends (``router``, ``interface``, ``numeric``).
            xr: template used for ``xrv`` routers.
            junos: template used for ``vmx`` routers.
            ios: template used for ``csr`` routers.

        Raises:
            ValueError: if a link refers to a router that is not defined in
                ``topology["routers"]``.
        """
        self.routers = RouterList()
        templates = {"xrv": xr, "vmx": junos, "csr": ios}

        # Routers without an explicit id get their 1-based position instead.
        for position, (name, elem) in enumerate(topology["routers"].items(), 1):
            router = Router()
            router.id = elem.get("id", position)
            router.name = name
            router.type = elem["type"]
            # Unknown types keep template None, as before.
            router.template = templates.get(router.type)

            for key, value in elem.items():
                if key not in ("id", "type"):
                    router.add_attr(key, value)

            self.routers.append(router)

        # A topology without links is valid: there is simply nothing to wire.
        for position, link in enumerate(topology.get("links", []), 1):
            left = self._require_router(link["left"]["router"])
            right = self._require_router(link["right"]["router"])

            # Both ends share one id; "octet" tells the two ends apart.
            link_id = max(left.id, right.id) + position
            left.links.append(
                self._link_entry(link["left"], link["right"], link_id, 1))
            right.links.append(
                self._link_entry(link["right"], link["left"], link_id, 2))

    def _require_router(self, name):
        """Return the router called *name* or raise a descriptive error."""
        router = self.routers.get(name)
        if router is None:
            raise ValueError(
                "Link references unknown router: {}".format(name))
        return router

    @staticmethod
    def _link_entry(local, remote, link_id, octet):
        """Build the link description stored on the *local* end's router."""
        return {
            "interface": local["interface"],
            "numeric": local["numeric"],
            "id": link_id,
            "octet": octet,
            "remote": {
                "router": remote["router"],
                "interface": remote["interface"],
                "numeric": remote["numeric"],
            },
        }

    def connect(self):
        """Connect to every router, honouring ``self.wait_for_boot``."""
        for router in self.routers.list():
            router.connect(self.wait_for_boot)

    def render_config(self):
        """Render each router's template and load it as a merge candidate.

        The template context holds ``hostname``, ``links`` and ``id`` plus
        every extra router attribute (attributes win on a name clash).

        Raises:
            ValueError: if a router has no template assigned.
        """
        # One environment for all routers: building it is loop-invariant.
        env = jinja2.Environment(loader=jinja2.FileSystemLoader(['./']))

        for router in self.routers.list():
            if not router.template:
                raise ValueError(
                    "No template configured for router {} (type {})".format(
                        router.name, router.type))

            config = {
                "hostname": router.name,
                "links": router.links,
                "id": router.id,
            }
            config.update(router.attrs)

            template = env.get_template(router.template)
            router.config = template.render(config)
            router.load_merge()

    def apply_config(self):
        """Commit the loaded candidate configuration on every router."""
        for router in self.routers.list():
            router.commit_config()

    def diff_config(self):
        """Print each router's config comparison, then discard the candidate."""
        for router in self.routers.list():
            print(router.compare_config())
            router.discard_config()

def parse_attrs(entries):
    """Turn ``KEY=VALUE`` strings from repeated ``--attr`` options into a dict.

    Only the first ``=`` separates key from value, so a value may itself
    contain ``=``. *entries* may be None (option not given).

    Raises:
        ValueError: if an entry contains no ``=`` at all.
    """
    attrs = {}
    for entry in entries or []:
        key, separator, value = entry.partition("=")
        if not separator:
            raise ValueError("Expected KEY=VALUE, got {!r}".format(entry))
        attrs[key] = value
    return attrs


def run_actions(cb, run, diff):
    """Connect to all routers, then render, commit and/or diff as requested.

    The commit runs before the diff because ``diff_config`` discards the
    candidate configuration after printing it.
    """
    cb.connect()

    if run or diff:
        cb.render_config()

    if run:
        cb.apply_config()

    if diff:
        cb.diff_config()


def main():
    """Command-line entry point: configure a topology or a single router."""
    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--topo", help="Low-level topology file from topomachine")
    parser.add_argument("--xr", help="IOS-XR Template")
    parser.add_argument("--junos", help="JunOS template")
    parser.add_argument("--ios", help="IOS template")
    parser.add_argument("--router", help="Name of your virtual router to configure, don't use together with --topo")
    parser.add_argument("--config", help="Template to apply to your router specified with --router")
    parser.add_argument("--type", help="Type of router specified with --router")
    parser.add_argument("--attr", help="Add extra attribute exposed in your config template", action="append")
    parser.add_argument("--wait-for-boot", help="Retry connection until successful", default=False, action="store_true")
    parser.add_argument("--run", help="Apply configuration", default=False, action="store_true")
    parser.add_argument("--diff", help="Display configuration diff but don't commit", default=False, action="store_true")
    args = parser.parse_args()

    cb = ConfigBootstrap(args.wait_for_boot)

    if args.topo:
        if args.router:
            print("You sould use either --router or --topo")
            sys.exit(1)

        if not os.path.isfile(args.topo):
            print("Topology file doesn't exist")
            sys.exit(1)

        # Templates are optional, but a given one must exist on disk.
        for path, label in ((args.xr, "IOS-XR"),
                            (args.junos, "JunOS"),
                            (args.ios, "IOS")):
            if path and not os.path.isfile(path):
                print("{} template doesn't exist".format(label))
                sys.exit(1)

        with open(args.topo, "r") as input_file:
            topology = json.load(input_file)

        cb.load_topology(topology, args.xr, args.junos, args.ios)
        run_actions(cb, args.run, args.diff)

    elif args.router:
        if not args.config or not os.path.isfile(args.config):
            print("Configuration template doesn't exist")
            sys.exit(1)
        if args.type not in ["vmx", "csr", "xrv"]:
            print("Invalid router type {}".format(args.type))
            sys.exit(1)

        try:
            attrs = parse_attrs(args.attr)
        except ValueError:
            print("Failed to parse extra attributes")
            sys.exit(1)

        cb.load_router(args.router, args.type, args.config, attrs)
        run_actions(cb, args.run, args.diff)


if __name__ == '__main__':
    main()
